{ "cells": [ { "cell_type": "markdown", "id": "ecae2048", "metadata": {}, "source": [ "# Multi-Head Attention on Decoder\n", "\n", "This notebook augments the Bahdanau decoder with **multi-head attention** while keeping the GRU encoder/decoder structure unchanged.\n", "It demonstrates how to project queries/keys/values per head, concatenate the attended context, and continue training on the\n", "MT French/English dataset using the exact training/evaluation loop from the Bahdanau notebook.\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "05b6adb9", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": 1, "id": "7fb8a98d", "metadata": {}, "outputs": [], "source": [ "import importlib\n", "import os\n", "\n", "import hw7\n", "importlib.reload(hw7)  # pick up edits to hw7.py without restarting the kernel\n", "# The star import is assumed to provide d2l, torch, MultiHeadSeq2SeqDecoder and beam_search_translate\n", "from hw7 import *\n", "from tsv_seq2seq_data import TSVSeq2SeqData" ] },
{ "cell_type": "code", "execution_count": 2, "id": "2d010ab7", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "[Figure: SVG output of this cell, rendered by Matplotlib v3.10.7 (https://matplotlib.org/) on 2025-12-02T11:00:20.578186; plot contents not recoverable]\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Alternative dataset: a custom TSV sentence-pair file (disabled)\n", "# data_path = os.path.expanduser('~/Dropbox/CS6140/data/sentence_pairs_large.tsv')\n", "# data = TSVSeq2SeqData(\n", "#     path=data_path,\n", "#     batch_size=512,\n", "#     num_steps=25,\n", "#     min_freq=2,\n", "#     val_frac=0.05,\n", "#     test_frac=0.0,\n", "#     sample_percent=1,\n", "# )\n", "\n", "data = d2l.MTFraEng(batch_size=128)\n", "\n", "# Smaller configuration (disabled)\n", "# embed_size = 256\n", "# num_hiddens = 320\n", "# num_blks = 3\n", "# num_layers = 3\n", "# dropout = 0.4\n", "# num_heads = 4\n", "\n", "embed_size = 512\n", "num_hiddens = 512\n", "num_blks = 3  # not used below\n", "num_layers = 4\n", "dropout = 0.1\n", "num_heads = 8\n", "\n", "encoder = d2l.Seq2SeqEncoder(len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)\n", "decoder = MultiHeadSeq2SeqDecoder(len(data.tgt_vocab), embed_size,\n", "                                  num_hiddens, num_layers, num_heads=num_heads, dropout=dropout)\n", "model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'], lr=0.002)\n", "trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)\n", "trainer.fit(model, data)\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "81c1d6a9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go . => ['', '!'], bleu,0.000\n", "i lost . => [\"j'ai\", 'perdu', '.'], bleu,1.000\n", "he's calm . => ['il', 'a', '', '.'], bleu,0.000\n", "i'm home . 
=> ['je', 'suis', 'certain', '.'], bleu,0.512\n" ] } ], "source": [ "# engs = ['go .', 'i lost .', 'he\\'s calm .', 'i\\'m home .']\n", "# fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n", "# preds, _ = model.predict_step(\n", "#     data.build(engs, fras), d2l.try_gpu(), data.num_steps)\n", "# for en, fr, p in zip(engs, fras, preds):\n", "#     translation = []\n", "#     for token in data.tgt_vocab.to_tokens(p):\n", "#         if token == '<eos>':\n", "#             break\n", "#         translation.append(token)\n", "#     print(f'{en} => {translation}, bleu,'\n", "#           f'{d2l.bleu(\" \".join(translation), fr, k=2):.3f}')\n", "\n", "# examples = engs\n", "# references = fras" ] },
{ "cell_type": "code", "execution_count": 3, "id": "b64a14eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "vamos . => we're go! | reference: go . BLEU: 0.000\n", "me perdi . => i kicked | reference: i got lost . BLEU: 0.000\n", "esta tranquilo . => this will will | reference: he is calm . BLEU: 0.000\n", "estoy en casa . => i'm am at | reference: i am at home . BLEU: 0.353\n", "donde esta el tren ? => where this rent runs | reference: where is the train ? BLEU: 0.000\n", "necesito ayuda urgente . => i need to help to | reference: i need urgent help . BLEU: 0.548\n", "ayer llovio mucho en la ciudad . => we swam a lot tourists the in the | reference: it rained a lot in the city yesterday . BLEU: 0.456\n", "los ninos estan jugando en el parque . => the in in the the | reference: the children are playing in the park . BLEU: 0.363\n", "ella quiere aprender a hablar ingles muy bien . => she wants to speak to very well well. well. well. | reference: she wants to learn to speak english very well . BLEU: 0.338\n", "cuando llegara el proximo tren a madrid ? => when i got the the to started to to | reference: when will the next train to madrid arrive ? 
BLEU: 0.000\n" ] } ], "source": [ "examples = ['vamos .', 'me perdi .', 'esta tranquilo .', 'estoy en casa .', 'donde esta el tren ?', 'necesito ayuda urgente .',\n", "            'ayer llovio mucho en la ciudad .', 'los ninos estan jugando en el parque .', 'ella quiere aprender a hablar ingles muy bien .',\n", "            'cuando llegara el proximo tren a madrid ?']\n", "\n", "references = ['go .', 'i got lost .', 'he is calm .', 'i am at home .', 'where is the train ?',\n", "              'i need urgent help .', 'it rained a lot in the city yesterday .',\n", "              'the children are playing in the park .', 'she wants to learn to speak english very well .', 'when will the next train to madrid arrive ?']\n", "\n", "# Greedy decoding: predict_step picks the argmax token at every step\n", "preds, _ = model.predict_step(\n", "    data.build(examples, references), d2l.try_gpu(), data.num_steps)\n", "for src, tgt, pred in zip(examples, references, preds):\n", "    translation = []\n", "    for token in data.tgt_vocab.to_tokens(pred):\n", "        if token == '<eos>':  # stop at the end-of-sequence token\n", "            break\n", "        translation.append(token)\n", "\n", "    hypo = ' '.join(translation)\n", "    print(f\"{src} => {hypo} | reference: {tgt} BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")\n" ] },
{ "cell_type": "markdown", "id": "727fc606", "metadata": {}, "source": [ "Optional: beam search instead of the greedy (argmax) decoder. Rather than committing to the single most likely token at each step, beam search keeps the `beam_size` highest-scoring partial translations and compares the probabilities of entire sequences before picking the output. It is a pruned, heuristic search, so it is not guaranteed to find the most probable sequence.\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "fb1fae5e", "metadata": {}, "outputs": [], "source": [ "jj = -1\n", "\n", "# Build a single-example batch for the last example/reference pair.\n", "# Assuming data.build follows d2l.MTFraEng.build, it returns (source ids, decoder inputs,\n", "# source valid lengths, label ids); the fourth item holds token ids, not lengths.\n", "src, tgt_inputs, src_lens, labels = data.build([examples[jj]], [references[jj]])\n", "device = d2l.try_gpu()\n", "src, tgt_inputs, src_lens = src.to(device), tgt_inputs.to(device), src_lens.to(device)\n", "\n", "# Run encoder + decoder with teacher forcing to collect attention at every timestep\n", "enc_outputs = model.encoder(src, src_lens)\n", "dec_state = model.decoder.init_state(enc_outputs, src_lens)\n", "_ = model.decoder(tgt_inputs, dec_state)  # forward pass stores weights on decoder.attention_weights\n", "\n", "dec_attention_weights = model.decoder.attention_weights  # list length = number of decoder steps\n", "weights = torch.stack(dec_attention_weights)  # (num_steps, batch*num_heads, 1, num_keys)\n", "num_steps, flat_batch, _, num_keys = weights.shape\n", "num_heads = model.decoder.attention.num_heads\n", "batch_size = flat_batch // num_heads\n", "\n", "# With batch_size == 1 the order of the batch and head axes inside the flattened dimension does not matter\n", "weights = weights.view(num_steps, num_heads, batch_size, num_keys)\n", "weights = weights.permute(2, 1, 0, 3)  # (batch, num_heads, num_steps, num_keys)\n", "\n", "# Average over heads, then crop to actual source/target lengths\n", "avg_weights = weights.mean(dim=1, keepdim=True)  # (batch, 1, num_steps, num_keys)\n", "src_len = int(src_lens.item())\n", "# Decoder steps actually used = number of non-padding label tokens\n", "tgt_len = int((labels != data.tgt_vocab['<pad>']).sum().item())\n", "heatmap = avg_weights[:, :, :tgt_len, :src_len].cpu()\n", "\n", "d2l.show_heatmaps(\n", "    heatmap, xlabel='Key positions', ylabel='Query positions'\n", ")\n" ] },
{ "cell_type": "code", "execution_count": 6, "id": "4976201e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "vamos . => we're going to and the are all by the milky side of the dead. a lot. | reference: go . | BLEU: 0.000\n", "me perdi . => you have to pay me, and me. the one is getting a ride? the youngster is getting married. | reference: i got lost . | BLEU: 0.000\n", "esta tranquilo . => this sentence is in this way. and it is a bit of telling a lie. the same is a bit of the same. | reference: he is calm . | BLEU: 0.000\n", "estoy en casa . => i'm in the right now. i'm in the home perfectly. he has a lot of project. the right is getting better. | reference: i am at home . | BLEU: 0.000\n", "donde esta el tren ? => the one where to put this room. | reference: where is the train ? | BLEU: 0.000\n", "necesito ayuda urgente . => i need to need some extra money. i are more careful. | reference: i need urgent help . 
| BLEU: 0.089\n", "ayer llovio mucho en la ciudad . => it was very hot in the town. | reference: it rained a lot in the city yesterday . | BLEU: 0.332\n", "los ninos estan jugando en el parque . => the the leaves in the united states are in the end of the long ones. are always going to have to give it home and left. it is going to do. the length of the table. | reference: the children are playing in the park . | BLEU: 0.127\n", "ella quiere aprender a hablar ingles muy bien . => she wants to speak to speak to learn so much of the others. | reference: she wants to learn to speak english very well . | BLEU: 0.488\n", "cuando llegara el proximo tren a madrid ? => when i want to get the damned on the | reference: when will the next train to madrid arrive ? | BLEU: 0.000\n" ] } ], "source": [ "for src, tgt in zip(examples, references):\n", " src_sentence = src.lower().split()\n", " src_tokens = [data.src_vocab[token] for token in src_sentence]\n", " pred_ids = beam_search_translate(model, src_tokens, data, beam_size=5, max_steps=40)\n", " translation = data.tgt_vocab.to_tokens(pred_ids)\n", " hypo = ' '.join(translation)\n", " print(f\"{src} => {hypo} | reference: {tgt} | BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python_mac_d2l", "language": "python", "name": "d2l" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.14.0" } }, "nbformat": 4, "nbformat_minor": 5 }